{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YONnGjpAYUdU"
      },
      "source": [
        "\n",
        "\u003ca href=\"https://colab.research.google.com/github/google-research/bigbird/blob/master/bigbird/classifier/imdb.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zrtR2urJV3ST"
      },
      "source": [
        "##### Copyright 2020 The BigBird Authors\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xyasTfa-LVLe"
      },
      "outputs": [],
      "source": [
        "# Copyright 2020 The BigBird Authors. All Rights Reserved.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# =============================================================================="
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fcZZRDx505hq"
      },
      "source": [
        "## Set Up"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N94UyOdA0mCO"
      },
      "outputs": [],
      "source": [
        "!pip install git+https://github.com/google-research/bigbird.git -q"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0irPwcbBYvDV"
      },
      "outputs": [],
      "source": [
        "from bigbird.core import flags\n",
        "from bigbird.core import modeling\n",
        "from bigbird.core import utils\n",
        "from bigbird.classifier import run_classifier\n",
        "import tensorflow.compat.v2 as tf\n",
        "import tensorflow_datasets as tfds\n",
        "from tqdm import tqdm\n",
        "import sys\n",
        "\n",
        "FLAGS = flags.FLAGS\n",
        "if not hasattr(FLAGS, \"f\"): flags.DEFINE_string(\"f\", \"\", \"\")\n",
        "FLAGS(sys.argv)\n",
        "\n",
        "tf.enable_v2_behavior()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AJexg2zsxfHo"
      },
      "source": [
        "## Set options"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rph2sJ75kBNA"
      },
      "outputs": [],
      "source": [
        "FLAGS.data_dir = \"tfds://imdb_reviews/plain_text\"\n",
        "FLAGS.attention_type = \"block_sparse\"\n",
        "FLAGS.max_encoder_length = 4096  # reduce for quicker demo on free colab\n",
        "FLAGS.learning_rate = 1e-5\n",
        "FLAGS.num_train_steps = 2000\n",
        "FLAGS.attention_probs_dropout_prob = 0.0\n",
        "FLAGS.hidden_dropout_prob = 0.0\n",
        "FLAGS.use_gradient_checkpointing = True\n",
        "FLAGS.vocab_model_file = \"gpt2\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zuxI3V_3j57Y"
      },
      "outputs": [],
      "source": [
        "bert_config = flags.as_dictionary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kRF4TUEQxjXJ"
      },
      "source": [
        "## Define classification model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "J3yNdo5toQwq"
      },
      "outputs": [],
      "source": [
        "model = modeling.BertModel(bert_config)\n",
        "headl = run_classifier.ClassifierLossLayer(\n",
        "        bert_config[\"hidden_size\"], bert_config[\"num_labels\"],\n",
        "        bert_config[\"hidden_dropout_prob\"],\n",
        "        utils.create_initializer(bert_config[\"initializer_range\"]),\n",
        "        name=bert_config[\"scope\"]+\"/classifier\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DXOY78vbqHX9"
      },
      "outputs": [],
      "source": [
        "@tf.function(experimental_compile=True)\n",
        "def fwd_bwd(features, labels):\n",
        "  with tf.GradientTape() as g:\n",
        "    _, pooled_output = model(features, training=True)\n",
        "    loss, log_probs = headl(pooled_output, labels, True)\n",
        "  grads = g.gradient(loss, model.trainable_weights+headl.trainable_weights)\n",
        "  return loss, log_probs, grads"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DzoTMyQlxsRo"
      },
      "source": [
        "## Dataset pipeline"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 5770,
          "status": "ok",
          "timestamp": 1607595313569,
          "user": {
            "displayName": "Manzil Zaheer",
            "photoUrl": "",
            "userId": "06259716656099187509"
          },
          "user_tz": 480
        },
        "id": "Oo-NQTDZZs51",
        "outputId": "ed2a0713-e06a-442f-a188-191d1fdc494d"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\n"
          ]
        }
      ],
      "source": [
        "train_input_fn = run_classifier.input_fn_builder(\n",
        "        data_dir=FLAGS.data_dir,\n",
        "        vocab_model_file=FLAGS.vocab_model_file,\n",
        "        max_encoder_length=FLAGS.max_encoder_length,\n",
        "        substitute_newline=FLAGS.substitute_newline,\n",
        "        is_training=True)\n",
        "dataset = train_input_fn({'batch_size': 8})"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 1527,
          "status": "ok",
          "timestamp": 1607595315103,
          "user": {
            "displayName": "Manzil Zaheer",
            "photoUrl": "",
            "userId": "06259716656099187509"
          },
          "user_tz": 480
        },
        "id": "hRvmfaNUi-V5",
        "outputId": "18578022-0344-4d01-cb2d-048f0c4f0d78"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "(\u003ctf.Tensor: shape=(8, 4096), dtype=int32, numpy=\n",
            "array([[   65,   733,   474, ...,     0,     0,     0],\n",
            "       [   65,   415, 26500, ...,     0,     0,     0],\n",
            "       [   65,   484, 20677, ...,     0,     0,     0],\n",
            "       ...,\n",
            "       [   65,   418,  1150, ...,     0,     0,     0],\n",
            "       [   65,  9271,  5714, ...,     0,     0,     0],\n",
            "       [   65,  8301,   113, ...,     0,     0,     0]], dtype=int32)\u003e, \u003ctf.Tensor: shape=(8,), dtype=int32, numpy=array([0, 1, 1, 1, 1, 0, 1, 0], dtype=int32)\u003e)\n",
            "(\u003ctf.Tensor: shape=(8, 4096), dtype=int32, numpy=\n",
            "array([[  65, 1182,  358, ...,    0,    0,    0],\n",
            "       [  65,  871,  419, ...,    0,    0,    0],\n",
            "       [  65,  415, 1908, ...,    0,    0,    0],\n",
            "       ...,\n",
            "       [  65,  484, 1722, ...,    0,    0,    0],\n",
            "       [  65,  876, 1154, ...,    0,    0,    0],\n",
            "       [  65,  415, 1092, ...,    0,    0,    0]], dtype=int32)\u003e, \u003ctf.Tensor: shape=(8,), dtype=int32, numpy=array([0, 1, 0, 0, 1, 0, 0, 1], dtype=int32)\u003e)\n",
            "(\u003ctf.Tensor: shape=(8, 4096), dtype=int32, numpy=\n",
            "array([[   65,   456,   382, ...,     0,     0,     0],\n",
            "       [   65,   484, 34679, ...,     0,     0,     0],\n",
            "       [   65, 16224,   112, ...,     0,     0,     0],\n",
            "       ...,\n",
            "       [   65,   484,  3822, ...,     0,     0,     0],\n",
            "       [   65,   484,  2747, ...,     0,     0,     0],\n",
            "       [   65,   415,  1208, ...,     0,     0,     0]], dtype=int32)\u003e, \u003ctf.Tensor: shape=(8,), dtype=int32, numpy=array([0, 0, 0, 0, 1, 0, 1, 0], dtype=int32)\u003e)\n"
          ]
        }
      ],
      "source": [
        "# inspect at a few examples\n",
        "for ex in dataset.take(3):\n",
        "  print(ex)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lYCyGH56zOOU"
      },
      "source": [
        "## (Optionally) Check outputs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 458,
          "status": "ok",
          "timestamp": 1607595411541,
          "user": {
            "displayName": "Manzil Zaheer",
            "photoUrl": "",
            "userId": "06259716656099187509"
          },
          "user_tz": 480
        },
        "id": "5uQOwyGQzRKt",
        "outputId": "6db22a02-3689-4b86-e6ed-b67eabbfc743"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "Loss:  0.6977416\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\n"
          ]
        }
      ],
      "source": [
        "loss, log_probs, grads = fwd_bwd(ex[0], ex[1])\n",
        "print('Loss: ', loss.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Qz_LdCCdyDCR"
      },
      "source": [
        "## (Optionally) Load pretrained model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 36637,
          "status": "ok",
          "timestamp": 1607595448644,
          "user": {
            "displayName": "Manzil Zaheer",
            "photoUrl": "",
            "userId": "06259716656099187509"
          },
          "user_tz": 480
        },
        "id": "rRa2dD1RzLN4",
        "outputId": "225e476b-2314-428a-b4ee-d267fb934a70"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 199/199 [00:34\u003c00:00,  4.94it/s]\n"
          ]
        }
      ],
      "source": [
        "ckpt_path = 'gs://bigbird-transformer/pretrain/bigbr_base/model.ckpt-0'\n",
        "ckpt_reader = tf.compat.v1.train.NewCheckpointReader(ckpt_path)\n",
        "model.set_weights([ckpt_reader.get_tensor(v.name[:-2]) for v in tqdm(model.trainable_weights, position=0)])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r6-BziYxzL3U"
      },
      "source": [
        "## Train"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 5080359,
          "status": "ok",
          "timestamp": 1607600529015,
          "user": {
            "displayName": "Manzil Zaheer",
            "photoUrl": "",
            "userId": "06259716656099187509"
          },
          "user_tz": 480
        },
        "id": "IWjkDvu9k7ie",
        "outputId": "67dcf3e1-c126-4291-90bc-da71b8c07c52"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Loss = 0.7094929218292236  Accuracy = 0.5"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "  0%|          | 0/2000 [00:06\u003c1:59:12,  3.57it/s]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "Loss = 0.47779741883277893  Accuracy = 0.7558900713920593"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            " 10%|█         | 200/2000 [11:26\u003c1:48:08,  3.60it/s]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "Loss = 0.3703668415546417  Accuracy = 0.8318414092063904"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            " 20%|██        | 400/2000 [23:52\u003c1:35:17,  3.58it/s]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "Loss = 0.3130376636981964  Accuracy = 0.8654822111129761"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            " 30%|███       | 600/2000 [35:18\u003c1:24:58,  3.60it/s]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "Loss = 0.2806303799152374  Accuracy = 0.8822692632675171"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            " 40%|████      | 800/2000 [47:44\u003c1:12:41,  3.60it/s]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "Loss = 0.2649693191051483  Accuracy = 0.8901362419128418"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            " 50%|█████     | 1000/2000 [59:10\u003c59:03,  3.58it/s]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "Loss = 0.25240564346313477  Accuracy = 0.8967254161834717"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            " 60%|██████    | 1200/2000 [1:11:36\u003c47:43,  3.60it/s]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "Loss = 0.24363534152507782  Accuracy = 0.901509702205658"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            " 70%|███████   | 1400/2000 [1:23:02\u003c35:20,  3.58it/s]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "Loss = 0.23414449393749237  Accuracy = 0.9062696695327759"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            " 80%|████████  | 1600/2000 [1:35:30\u003c23:23,  3.60it/s]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "Loss = 0.22541514039039612  Accuracy = 0.9101060628890991"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            " 90%|█████████ | 1800/2000 [1:46:05\u003c11:34,  3.60it/s]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "Loss = 0.2210962176322937  Accuracy = 0.9125439524650574"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 2000/2000 [1:59:39\u003c00:00,  3.58it/s]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\n"
          ]
        }
      ],
      "source": [
        "opt = tf.keras.optimizers.Adam(FLAGS.learning_rate)\n",
        "train_loss = tf.keras.metrics.Mean(name='train_loss')\n",
        "train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')\n",
        "\n",
        "for i, ex in enumerate(tqdm(dataset.take(FLAGS.num_train_steps), position=0)):\n",
        "  loss, log_probs, grads = fwd_bwd(ex[0], ex[1])\n",
        "  opt.apply_gradients(zip(grads, model.trainable_weights+headl.trainable_weights))\n",
        "  train_loss(loss)\n",
        "  train_accuracy(tf.one_hot(ex[1], 2), log_probs)\n",
        "  if i% 200 == 0:\n",
        "    print('Loss = {}  Accuracy = {}'.format(train_loss.result().numpy(), train_accuracy.result().numpy()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mXjkdtyKMAbP"
      },
      "source": [
        "## Eval"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cc3OatxlMCk3"
      },
      "outputs": [],
      "source": [
        "@tf.function(experimental_compile=True)\n",
        "def fwd_only(features, labels):\n",
        "  _, pooled_output = model(features, training=False)\n",
        "  loss, log_probs = headl(pooled_output, labels, False)\n",
        "  return loss, log_probs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-Mq_xhMzef42"
      },
      "outputs": [],
      "source": [
        "eval_input_fn = run_classifier.input_fn_builder(\n",
        "        data_dir=FLAGS.data_dir,\n",
        "        vocab_model_file=FLAGS.vocab_model_file,\n",
        "        max_encoder_length=FLAGS.max_encoder_length,\n",
        "        substitute_newline=FLAGS.substitute_newline,\n",
        "        is_training=False)\n",
        "eval_dataset = eval_input_fn({'batch_size': 8})"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 3263,
          "status": "ok",
          "timestamp": 1607617729500,
          "user": {
            "displayName": "Manzil Zaheer",
            "photoUrl": "",
            "userId": "06259716656099187509"
          },
          "user_tz": 480
        },
        "id": "rqPN4R8kerUG",
        "outputId": "194f8765-f13d-46f9-f7fc-0b4b54c9e9d5"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "Loss = 0.16173037886619568  Accuracy = 0.9459513425827026100"
          ]
        }
      ],
      "source": [
        "eval_loss = tf.keras.metrics.Mean(name='eval_loss')\n",
        "eval_accuracy = tf.keras.metrics.CategoricalAccuracy(name='eval_accuracy')\n",
        "\n",
        "for ex in tqdm(eval_dataset, position=0):\n",
        "  loss, log_probs = fwd_only(ex[0], ex[1])\n",
        "  eval_loss(loss)\n",
        "  eval_accuracy(tf.one_hot(ex[1], 2), log_probs)\n",
        "print('Loss = {}  Accuracy = {}'.format(eval_loss.result().numpy(), eval_accuracy.result().numpy()))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BvEFgoXJxQOa"
      },
      "outputs": [],
      "source": [
        ""
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {},
      "name": "BigBirdGPU.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
